# Copyright 2020 Nagoya University (Tomoki Hayashi)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import torch

from Architectures.GeneralLayers.Attention import MultiHeadedAttention as BaseMultiHeadedAttention


class GSTStyleEncoder(torch.nn.Module):
    """Style encoder.

    This module is the style encoder introduced in `Style Tokens: Unsupervised Style
    Modeling, Control and Transfer in End-to-End Speech Synthesis`.

    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
        Speech Synthesis`: https://arxiv.org/abs/1803.09017

    Args:
        idim (int, optional): Dimension of the input features.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_token_dim (int, optional): Dimension of the output style embedding. The
            learned token table itself is ``gst_token_dim // gst_heads`` wide.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        conv_layers (int, optional): The number of conv layers in the reference encoder.
        conv_chans_list: (Sequence[int], optional):
            List of the number of channels of conv layers in the reference encoder.
        conv_kernel_size (int, optional):
            Kernel size of conv layers in the reference encoder.
        conv_stride (int, optional):
            Stride size of conv layers in the reference encoder.
        gst_layers (int, optional): The number of GRU layers in the reference encoder.
        gst_units (int, optional): The number of GRU units in the reference encoder.
    """

    def __init__(
        self,
        idim: int = 128,
        gst_tokens: int = 512,  # adaspeech suggests to use many more "basis vectors", but I believe that this is already sufficient
        gst_token_dim: int = 64,
        gst_heads: int = 8,
        conv_layers: int = 8,
        conv_chans_list=(32, 32, 64, 64, 128, 128, 256, 256),
        conv_kernel_size: int = 3,
        conv_stride: int = 2,
        gst_layers: int = 2,
        gst_units: int = 256,
    ):
        """Initialize global style encoder module."""
        super(GSTStyleEncoder, self).__init__()

        self.num_tokens = gst_tokens
        self.ref_enc = ReferenceEncoder(idim=idim,
                                        conv_layers=conv_layers,
                                        conv_chans_list=conv_chans_list,
                                        conv_kernel_size=conv_kernel_size,
                                        conv_stride=conv_stride,
                                        gst_layers=gst_layers,
                                        gst_units=gst_units, )
        self.stl = StyleTokenLayer(ref_embed_dim=gst_units,
                                   gst_tokens=gst_tokens,
                                   gst_token_dim=gst_token_dim,
                                   gst_heads=gst_heads, )

    def forward(self, speech):
        """Calculate forward propagation.

        Args:
            speech (Tensor): Batch of padded target features (B, Lmax, odim).

        Returns:
            Tensor: Style token embeddings (B, gst_token_dim).
        """
        ref_embs = self.ref_enc(speech)
        style_embs = self.stl(ref_embs)

        return style_embs

    def calculate_ada4_regularization_loss(self):
        """Sum the cosine similarity over every unordered pair of distinct style tokens.

        The similarity is computed on the raw (pre-tanh) token parameters
        ``self.stl.gst_embs``. Only the first ``self.num_tokens`` rows are used.
        Presumably the caller adds this value to the training loss to discourage
        redundant tokens (TODO confirm against the caller).

        Returns:
            Tensor: Scalar tensor holding the sum of cos(e_i, e_j) for all i < j.
            It is a zero tensor when fewer than two tokens exist.
        """
        embs = self.stl.gst_embs[:self.num_tokens]
        # Row-normalise once. A single matrix product then gives every pairwise
        # cosine similarity, replacing ~num_tokens**2 / 2 separate kernel calls.
        normed = torch.nn.functional.normalize(embs, p=2, dim=1)
        similarity = normed @ normed.t()
        # Keep only the strict upper triangle. Each unordered pair (i < j) is then
        # counted exactly once, and the diagonal self-similarities are excluded.
        return torch.triu(similarity, diagonal=1).sum()


class ReferenceEncoder(torch.nn.Module):
    """Reference encoder module.

    This module is the reference encoder introduced in `Style Tokens: Unsupervised Style
    Modeling, Control and Transfer in End-to-End Speech Synthesis`.

    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
        Speech Synthesis`: https://arxiv.org/abs/1803.09017

    Args:
        idim (int, optional): Dimension of the input features.
        conv_layers (int, optional): The number of conv layers in the reference encoder.
            Zero is allowed; the GRU then reads the raw features directly.
        conv_chans_list: (Sequence[int], optional):
            List of the number of channels of conv layers in the reference encoder.
        conv_kernel_size (int, optional):
            Kernel size of conv layers in the reference encoder.
        conv_stride (int, optional):
            Stride size of conv layers in the reference encoder.
        gst_layers (int, optional): The number of GRU layers in the reference encoder.
        gst_units (int, optional): The number of GRU units in the reference encoder.
    """

    def __init__(
        self,
        idim=80,
        conv_layers: int = 6,
        conv_chans_list=(32, 32, 64, 64, 128, 128),
        conv_kernel_size: int = 3,
        conv_stride: int = 2,
        gst_layers: int = 1,
        gst_units: int = 128,
    ):
        """Initialize reference encoder module."""
        super(ReferenceEncoder, self).__init__()

        # check hyperparameters are valid
        assert conv_kernel_size % 2 == 1, "kernel size must be odd."
        assert (
                len(conv_chans_list) == conv_layers), "the number of conv layers and length of channels list must be the same."

        convs = []
        # "Same"-style padding for an odd kernel, so only the stride shrinks the maps.
        padding = (conv_kernel_size - 1) // 2
        for i in range(conv_layers):
            conv_in_chans = 1 if i == 0 else conv_chans_list[i - 1]
            conv_out_chans = conv_chans_list[i]
            convs += [torch.nn.Conv2d(conv_in_chans,
                                      conv_out_chans,
                                      kernel_size=conv_kernel_size,
                                      stride=conv_stride,
                                      padding=padding,
                                      # Do not use bias due to the following batch norm
                                      bias=False, ),
                      torch.nn.BatchNorm2d(conv_out_chans),
                      torch.nn.ReLU(inplace=True), ]
        self.convs = torch.nn.Sequential(*convs)

        self.conv_layers = conv_layers
        self.kernel_size = conv_kernel_size
        self.stride = conv_stride
        self.padding = padding

        # get the number of GRU input units: apply the conv output-size formula
        # once per layer to the feature axis, then multiply by the channel count
        gst_in_units = idim
        for _ in range(conv_layers):
            gst_in_units = (gst_in_units - conv_kernel_size + 2 * padding) // conv_stride + 1
        # With no conv layers the single input channel passes straight through.
        last_chans = conv_chans_list[-1] if conv_layers > 0 else 1
        gst_in_units *= last_chans
        self.gst = torch.nn.GRU(gst_in_units, gst_units, gst_layers, batch_first=True)

    def forward(self, speech):
        """Calculate forward propagation.

        Args:
            speech (Tensor): Batch of padded target features (B, Lmax, idim).

        Returns:
            Tensor: Reference embedding (B, gst_units).
        """
        batch_size = speech.size(0)
        xs = speech.unsqueeze(1)  # (B, 1, Lmax, idim)
        hs = self.convs(xs).transpose(1, 2)  # (B, Lmax', conv_out_chans, idim')
        time_length = hs.size(1)
        # Merge channel and feature axes into one GRU input vector per frame.
        hs = hs.contiguous().view(batch_size, time_length, -1)  # (B, Lmax', conv_out_chans * idim')
        self.gst.flatten_parameters()
        # NOTE: no sequence lengths are available here, so padded frames are fed to the
        # GRU as-is (no pack_padded_sequence); the final state also covers trailing padding.
        _, ref_embs = self.gst(hs)  # (gst_layers, B, gst_units)
        ref_embs = ref_embs[-1]  # (B, gst_units): final hidden state of the last GRU layer

        return ref_embs


class StyleTokenLayer(torch.nn.Module):
    """Style token layer module.

    This module is the style token layer introduced in `Style Tokens: Unsupervised Style
    Modeling, Control and Transfer in End-to-End Speech Synthesis`.

    .. _`Style Tokens: Unsupervised Style Modeling, Control and Transfer in End-to-End
        Speech Synthesis`: https://arxiv.org/abs/1803.09017

    Args:
        ref_embed_dim (int, optional): Dimension of the input reference embedding.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_token_dim (int, optional): Dimension of the produced style embedding; each
            learned token is stored with width ``gst_token_dim // gst_heads``.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        dropout_rate (float, optional): Dropout rate in multi-head attention.
    """

    def __init__(
        self,
        ref_embed_dim: int = 128,
        gst_tokens: int = 10,
        gst_token_dim: int = 128,
        gst_heads: int = 4,
        dropout_rate: float = 0.0,
    ):
        """Create the learnable token table and the attention that reads it."""
        super(StyleTokenLayer, self).__init__()

        # Tokens are narrow; the attention's key/value projections lift them to
        # gst_token_dim.
        token_width = gst_token_dim // gst_heads
        # Assigning a Parameter registers it under the name "gst_embs".
        self.gst_embs = torch.nn.Parameter(torch.randn(gst_tokens, token_width))
        self.mha = MultiHeadedAttention(q_dim=ref_embed_dim,
                                        k_dim=token_width,
                                        v_dim=token_width,
                                        n_head=gst_heads,
                                        n_feat=gst_token_dim,
                                        dropout_rate=dropout_rate, )

    def forward(self, ref_embs):
        """Attend from each reference embedding over the shared style tokens.

        Args:
            ref_embs (Tensor): Reference embeddings (B, ref_embed_dim).

        Returns:
            Tensor: Style token embeddings (B, gst_token_dim).
        """
        # NOTE(kan-bayashi): Should we apply Tanh?
        # Squash the token table into (-1, 1), then broadcast it over the batch as a view:
        # (num_tokens, token_width) -> (B, num_tokens, token_width).
        tokens = torch.tanh(self.gst_embs)
        tokens = tokens.unsqueeze(0).expand(ref_embs.size(0), -1, -1)
        # Each reference embedding becomes a length-1 query: (B, 1, ref_embed_dim).
        query = ref_embs.unsqueeze(1)
        attended = self.mha(query, tokens, tokens, None)
        return attended.squeeze(1)


class MultiHeadedAttention(BaseMultiHeadedAttention):
    """Multi-head attention whose query, key and value inputs may differ in width."""

    def __init__(self, q_dim, k_dim, v_dim, n_head, n_feat, dropout_rate=0.0):
        """Build per-input projections into a shared ``n_feat``-wide attention space.

        Args:
            q_dim (int): Width of the query input.
            k_dim (int): Width of the key input.
            v_dim (int): Width of the value input.
            n_head (int): Number of attention heads; must divide ``n_feat`` evenly.
            n_feat (int): Width of the projected queries/keys/values and of the
                output projection.
            dropout_rate (float, optional): Probability for the ``dropout`` module.
        """
        # NOTE(kan-bayashi): Do not use super().__init__() here since we want to
        #   overwrite BaseMultiHeadedAttention.__init__() method.
        torch.nn.Module.__init__(self)
        assert n_feat % n_head == 0
        self.h = n_head
        # Per-head width; d_v is assumed to always equal d_k.
        self.d_k = n_feat // n_head
        # Created in q, k, v order so parameter registration/initialisation order is unchanged.
        self.linear_q, self.linear_k, self.linear_v = (
            torch.nn.Linear(in_dim, n_feat) for in_dim in (q_dim, k_dim, v_dim)
        )
        self.linear_out = torch.nn.Linear(n_feat, n_feat)
        # Placeholder; presumably populated with attention weights by the inherited
        # forward pass (TODO confirm in BaseMultiHeadedAttention).
        self.attn = None
        self.dropout = torch.nn.Dropout(p=dropout_rate)
